
[Benchmark] Add compute_seq_len_sweep_config_with_probe with linear/quadratic scaling support #1218

Open

shivam2199 wants to merge 4 commits into linkedin:main from shivam2199:issue-1200-quadratic-probe-scaling

Conversation

@shivam2199
Contributor

Summary

Refs #1200. Addresses non-linear memory scaling in benchmark sweep config inference.

The existing compute_seq_len_sweep_config inverts memory via max_tokens = usable_bytes / kernel_bytes_per_token, which only holds for linear-scaling kernels. For O(L²) kernels (e.g. benchmark_sparse_multi_token_attention.py), this overestimates capacity by orders of magnitude — the existing workaround there divides by probe_L * probe_L, but the downstream sweep math still treats the result as linear bytes-per-token.
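To make the overestimate concrete, a worked example with hypothetical numbers (a probe at B=1, L=2048 peaking at 1 GiB, with 20 GiB usable):

```python
import math

GiB = 1024**3
usable_bytes, peak_bytes = 20 * GiB, 1 * GiB  # hypothetical measurements
B, L = 1, 2048                                # probe shape

# Linear inversion (existing path): treat the peak as B*L tokens' worth of bytes.
bytes_per_token = peak_bytes / (B * L)
linear_L = int(usable_bytes / (B * bytes_per_token))  # 40960

# What an O(L^2) kernel would actually need at that predicted length:
actual_bytes = peak_bytes * (linear_L / L) ** 2       # 400 GiB -> guaranteed OOM

# Quadratic inversion: peak ~= c * B * L^2, so L_max = sqrt(usable / (B * c)).
c = peak_bytes / (B * L * L)
quadratic_L = int(math.sqrt(usable_bytes / (B * c)))  # ~9158
```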

Per discussion on the issue (#1200 (comment)), this PR adds a new helper rather than threading scaling_method through the existing function — 16+ benchmark scripts call estimate_kernel_peak_memory today, and a wider signature change would conflict with in-flight benchmark refactors (#1199, #1180). Linear-scaling callers are unchanged; only quadratic-scaling benchmarks opt in.

What changed

  • benchmark/scripts/benchmark_model_configs.py — adds compute_seq_len_sweep_config_with_probe(model_cfg, probe_fn, probe_seq_len, probe_batch_size=1, scaling_method="linear" | "quadratic", ...). Internalizes the probe call + inversion; reuses estimate_kernel_peak_memory for the measurement.
  • benchmark/scripts/benchmark_sparse_multi_token_attention.py — switches the token_length sweep mode to the new helper with scaling_method="quadratic", dropping the manual peak_bytes // (probe_L * probe_L) workaround.

estimate_kernel_peak_memory and compute_seq_len_sweep_config are untouched.
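For illustration, opting a quadratic-scaling benchmark in might look like the sketch below; the probe body and the call convention for probe_fn are assumptions, only the keyword names come from the signature above:

```python
def attention_probe(batch_size: int, seq_len: int) -> None:
    # Hypothetical probe: run the O(L^2) kernel once at (B, L) so the
    # helper's internal estimate_kernel_peak_memory call can measure its peak.
    run_sparse_multi_token_attention(batch_size, seq_len)  # hypothetical runner

sweep_cfg = compute_seq_len_sweep_config_with_probe(
    model_cfg,                   # e.g. the LLAMA_3_8B config used in Validation
    probe_fn=attention_probe,
    probe_seq_len=2048,
    probe_batch_size=2,
    scaling_method="quadratic",  # opt in to the O(L^2) inversion
)
```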

Validation

Hardware: A10G 24GB (g5.xlarge).

Synthetic O(L²) probe (B=2, L=2048, allocates B * L * L floats) using LLAMA_3_8B config and max_seq_len=2**20 to bypass the model cap so the raw inversion is visible:

quadratic: SeqLenSweepConfig(batch_size=2, seq_len=8192)
linear:    SeqLenSweepConfig(batch_size=2, seq_len=65536)

The 8× gap (≈17× before snap-to-power-of-2) demonstrates the inversion difference: linear claims a sweep at L=65536 fits, yet the B * L * L float32 buffer alone is ≈34 GB at that size, well past the 24 GB card; quadratic lands at a realistic L=8192. This matches the issue's premise: for non-linear-scaling kernels, the existing inversion overestimates capacity and would OOM at the predicted boundary.
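The synthetic probe amounts to something like the following sketch (the harness details may differ; the load-bearing part is the B * L * L allocation):

```python
import torch

def synthetic_quadratic_probe(batch_size: int, seq_len: int) -> None:
    # Peak memory is batch_size * seq_len * seq_len float32 elements,
    # i.e. O(L^2): ~32 MiB at B=2, L=2048, quadrupling each time L doubles.
    scores = torch.ones(batch_size, seq_len, seq_len,
                        dtype=torch.float32, device="cuda")
```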

Testing Done

  • Synthetic O(L²) sanity check on A10G — confirms quadratic predicts L=8192 vs linear predicts L=65536 for the same probe (8× separation, scales as expected).
  • benchmark_sparse_multi_token_attention.py imports + helper resolution verified locally.
  • Full sparse-attention end-to-end sweep on A10G (deferred — synthetic test already isolates the inversion math from kernel-specific noise).

cc @Tcc0403

shivam2199 and others added 2 commits May 7, 2026 22:13
…r scaling (linkedin#1200)

Adds a new helper alongside the existing compute_seq_len_sweep_config that
internalizes both the probe and the seq-len inversion, with a scaling_method
argument supporting "linear" (default) and "quadratic". For O(L^2) kernels,
the inversion uses L_max = sqrt(usable / (B * c_per_BL2)) instead of the
linear max_tokens / batch_size path.

Migrates benchmark_sparse_multi_token_attention.py to the new helper and
drops its manual `peak_bytes // (probe_L * probe_L)` workaround.

The existing estimate_kernel_peak_memory and compute_seq_len_sweep_config
are unchanged; linear-scaling benchmark callers don't need to migrate.
@shivam2199
Contributor Author

@Tcc0403 @Mecoli1219 Please take a look

Comment on lines +341 to +351
```python
batch_size = max(1, min(max_batch_size, probe_batch_size))

if scaling_method == "linear":
    c_per_BL = max(1.0, peak_bytes / (probe_batch_size * probe_seq_len))
    max_seq_len_from_mem = max(1, int(usable_bytes / (batch_size * c_per_BL)))
else:
    c_per_BL2 = max(1.0, peak_bytes / (probe_batch_size * probe_seq_len * probe_seq_len))
    max_seq_len_from_mem = max(1, int(math.sqrt(usable_bytes / (batch_size * c_per_BL2))))

seq_len = min(max_seq_len, max_seq_len_from_mem)
seq_len = 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024
```
@Tcc0403
Collaborator
Is it possible to just plug this part into compute_seq_len_sweep_config?

@shivam2199
Contributor Author

@Tcc0403 Good call. Pushed 6c204db which extracts two private helpers — _max_seqlen_under_memory (handles both linear and quadratic inversion) and _snap_pow2_seqlen — and collapses both public functions to thin orchestration over them.

compute_seq_len_sweep_config treats kernel_bytes_per_token as a unit-probe (B=L=1, linear) so the inversion math reduces to the existing max_tokens = usable / bpt quantity. No behavior change for the existing callers; the duplicated inversion/snap logic is gone.
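Reconstructed from the quoted diff and this description (a sketch, not the actual commit), the extracted helpers plausibly look like:

```python
import math

def _max_seqlen_under_memory(usable_bytes, peak_bytes, probe_batch_size,
                             probe_seq_len, batch_size, scaling_method="linear"):
    # Invert a probe measurement to the longest seq_len that fits in memory.
    if scaling_method == "linear":
        c_per_BL = max(1.0, peak_bytes / (probe_batch_size * probe_seq_len))
        return max(1, int(usable_bytes / (batch_size * c_per_BL)))
    # quadratic: peak ~= c * B * L^2, so L_max = sqrt(usable / (B * c)).
    c_per_BL2 = max(1.0, peak_bytes / (probe_batch_size * probe_seq_len**2))
    return max(1, int(math.sqrt(usable_bytes / (batch_size * c_per_BL2))))

def _snap_pow2_seqlen(seq_len):
    # Round down to a power of two, flooring at 1024.
    return 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024
```

With a unit probe (probe_batch_size=1, probe_seq_len=1, linear), the first branch reduces to usable_bytes / (batch_size * kernel_bytes_per_token), which is exactly the old max_tokens arithmetic.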

shivam2199 and others added 2 commits May 8, 2026 20:43
Per @Tcc0403 review: instead of two parallel implementations of the
inversion + power-of-2 snap, extract `_max_seqlen_under_memory` (handles
both linear and quadratic) and `_snap_pow2_seqlen`. Both public APIs
become thin orchestration layers over them.

`compute_seq_len_sweep_config` now treats `kernel_bytes_per_token` as a
unit-probe (B=L=1, scaling=linear) so the math collapses to the existing
`max_tokens = usable / bpt` behavior — no behavior change for the 16+
existing callers.
